{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0e95f0df-4d1a-4e2e-92ff-90539bb4c517",
   "metadata": {},
   "source": [
    "# Example 06: CUDA Graphs\n",
    "\n",
    "In this example we demonstrate how to use CUDA graphs from PyTorch with the CuTe DSL.\n",
    "Working with PyTorch's CUDA graph implementation requires exposing PyTorch's CUDA streams to CUTLASS, which we do by wrapping their raw handles in `CUstream` objects.\n",
    "\n",
    "Using CUDA graphs on Blackwell requires a version of PyTorch that supports Blackwell.\n",
    "Such a build can be obtained through:\n",
    "- The [PyTorch NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)\n",
    "- [PyTorch 2.7 with CUDA 12.8 or later](https://pytorch.org/) (e.g., `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128`)\n",
    "- Building PyTorch from source against your version of CUDA."
   ]
  },
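  {
   "cell_type": "markdown",
   "id": "cuda-graphs-arch-check",
   "metadata": {},
   "source": [
    "To check whether an installed PyTorch build targets Blackwell, one option is to inspect `torch.cuda.get_arch_list()`. It lists the GPU architectures the build was compiled for, with entries such as `sm_90` or `compute_120`. The helper below is a hypothetical sketch and is not part of PyTorch or CUTLASS. It assumes that Blackwell architectures are numbered `sm_100` and above (for example `sm_100` or `sm_120`), and it also treats any newer architecture as supported.\n",
    "\n",
    "```python\n",
    "def arch_list_supports_blackwell(arch_list):\n",
    "    # Hypothetical helper: True if any entry targets compute capability 10.0 or newer\n",
    "    for arch in arch_list:\n",
    "        # Entries look like \"sm_90\" or \"compute_120\"; some carry a suffix, e.g. \"sm_90a\"\n",
    "        prefix, _, number = arch.partition(\"_\")\n",
    "        digits = \"\".join(ch for ch in number if ch.isdigit())\n",
    "        if prefix in (\"sm\", \"compute\") and digits and int(digits) >= 100:\n",
    "            return True\n",
    "    return False\n",
    "```\n",
    "\n",
    "For example, `arch_list_supports_blackwell(torch.cuda.get_arch_list())` should return `True` on a build that includes `sm_100` or `sm_120`. An empty list, which is what PyTorch returns when CUDA is unavailable, yields `False`."
   ]
  },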
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "46b8fb6f-9ac5-4a3d-b765-b6476f182bf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import torch for CUDA graphs\n",
    "import torch\n",
    "import cutlass\n",
    "import cutlass.cute as cute\n",
    "# import CUstream type from the cuda driver bindings\n",
    "from cuda.bindings.driver import CUstream\n",
    "# import the current_stream function from torch\n",
    "from torch.cuda import current_stream"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bcf5e06e-1f5b-4d72-ad73-9b36efb78ca0",
   "metadata": {},
   "source": [
    "## Kernel Creation\n",
    "\n",
    "We create a kernel that prints \"Hello world\", along with a host function that launches it.\n",
    "We then compile the host function for use in our graph, passing in PyTorch's current stream.\n",
    "\n",
    "The kernel must be compiled before graph capture. A CUDA graph replays only the kernel launches recorded during capture, so it cannot JIT compile kernels during graph execution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0c2a6ca8-98d7-4837-b91f-af769ca8fcd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "@cute.kernel\n",
    "def hello_world_kernel():\n",
    "    \"\"\"\n",
    "    A kernel that prints hello world\n",
    "    \"\"\"\n",
    "    cute.printf(\"Hello world\")\n",
    "\n",
    "@cute.jit\n",
    "def hello_world(stream : CUstream):\n",
    "    \"\"\"\n",
    "    Host function that launches hello_world_kernel on `stream` with a (1,1,1) grid of (1,1,1) blocks\n",
    "    \"\"\"\n",
    "    hello_world_kernel().launch(grid=[1, 1, 1], block=[1, 1, 1], stream=stream)\n",
    "\n",
    "# Grab a stream from PyTorch; this also initializes the CUDA context,\n",
    "# so we can omit cutlass.cuda.initialize_cuda_context()\n",
    "stream = current_stream()\n",
    "hello_world_compiled = cute.compile(hello_world, CUstream(stream.cuda_stream))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ecc850af-09f8-4a29-9c93-ff31fbb9326f",
   "metadata": {},
   "source": [
    "## Creating and replaying a CUDA Graph\n",
    "\n",
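    "We create a `torch.cuda.CUDAGraph` and capture into it with the `torch.cuda.graph` context manager.\n",
    "Inside the context, PyTorch makes its capture stream the current stream. A specific stream can also be supplied through the `stream` argument of `torch.cuda.graph`. We wrap the current stream's handle in a `CUstream` and pass it to the compiled kernel, so the kernel launches are recorded into the graph.\n",
    "\n",
    "Finally, we replay the graph and synchronize."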
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f673e5ae-42bb-44d0-b652-3280606181c4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hello world\n",
      "Hello world\n"
     ]
    }
   ],
   "source": [
    "# Create a CUDA Graph\n",
    "g = torch.cuda.CUDAGraph()\n",
    "# Capture our graph\n",
    "with torch.cuda.graph(g):\n",
    "    # Inside the capture context, current_stream() is the stream being captured.\n",
    "    # Wrap its raw handle (.cuda_stream) in a CUstream so CuTe DSL can launch on it.\n",
    "    graph_stream = CUstream(current_stream().cuda_stream)\n",
    "    # Run 2 iterations of our compiled kernel\n",
    "    for _ in range(2):\n",
    "        # Run our kernel in the stream\n",
    "        hello_world_compiled(graph_stream)\n",
    "\n",
    "# Replay our graph\n",
    "g.replay()\n",
    "# Synchronize all streams (equivalent to cudaDeviceSynchronize() in C++)\n",
    "torch.cuda.synchronize()"
   ]
  },
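  {
   "cell_type": "markdown",
   "id": "cuda-graphs-explicit-stream",
   "metadata": {},
   "source": [
    "The capture above lets PyTorch choose the capture stream. `torch.cuda.graph` also accepts a `stream` argument, and that stream becomes the current stream inside the context. The following is a minimal sketch of capturing on a caller-provided stream. `capture_on_stream` is a hypothetical helper name, and the sketch assumes the compiled function takes a single `CUstream` argument, like `hello_world_compiled`.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from cuda.bindings.driver import CUstream\n",
    "\n",
    "\n",
    "def capture_on_stream(compiled_fn, iterations, capture_stream=None):\n",
    "    # Hypothetical helper: record `iterations` launches of compiled_fn into a new graph\n",
    "    if capture_stream is None:\n",
    "        # Use a fresh side stream; capture is not allowed on the legacy default stream\n",
    "        capture_stream = torch.cuda.Stream()\n",
    "    graph = torch.cuda.CUDAGraph()\n",
    "    with torch.cuda.graph(graph, stream=capture_stream):\n",
    "        # capture_stream is current here, so launches issued on it are recorded\n",
    "        cu_stream = CUstream(capture_stream.cuda_stream)\n",
    "        for _ in range(iterations):\n",
    "            compiled_fn(cu_stream)\n",
    "    return graph\n",
    "```\n",
    "\n",
    "Usage would follow the same pattern as the cell above: `g2 = capture_on_stream(hello_world_compiled, 2)`, then `g2.replay()` and `torch.cuda.synchronize()`."
   ]
  },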
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "db76d9c3-7617-4bf2-b326-11982e6803bf",
   "metadata": {},
   "source": [
    "Our run results in the following execution when viewed in Nsight Systems:\n",
    "\n",
    "![Image of two hello world kernels run back to back in a CUDA graph](images/cuda_graphs_image.png)\n",
    "\n",
    "We can observe the launch of the two kernels followed by a `cudaDeviceSynchronize()`.\n",
    "\n",
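    "Next, we measure how much launch overhead the graph saves. We time 100 kernel launches issued directly to a stream and compare them against one replay of a graph that contains the same 100 launches:"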
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3ebe15bf-dc97-42e9-913c-224ecfb472e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n",
      "Hello world\n"
     ]
    }
   ],
   "source": [
    "# Get our CUDA stream from PyTorch\n",
    "stream = CUstream(current_stream().cuda_stream)\n",
    "\n",
    "# Create a larger CUDA Graph of 100 iterations\n",
    "g = torch.cuda.CUDAGraph()\n",
    "# Capture our graph\n",
    "with torch.cuda.graph(g):\n",
    "    # Inside the capture context, current_stream() is the stream being captured.\n",
    "    # Wrap its raw handle (.cuda_stream) in a CUstream so CuTe DSL can launch on it.\n",
    "    graph_stream = CUstream(current_stream().cuda_stream)\n",
    "    # Capture 100 iterations of our compiled kernel\n",
    "    for _ in range(100):\n",
    "        # Run our kernel in the stream\n",
    "        hello_world_compiled(graph_stream)\n",
    "\n",
    "# Create CUDA events for measuring performance\n",
    "start = torch.cuda.Event(enable_timing=True)\n",
    "end = torch.cuda.Event(enable_timing=True)\n",
    "\n",
    "# Run our kernel to warm up the GPU\n",
    "for _ in range(100):\n",
    "    hello_world_compiled(stream)\n",
    "\n",
    "# Record our start time\n",
    "start.record()\n",
    "# Run 100 kernels\n",
    "for _ in range(100):\n",
    "    hello_world_compiled(stream)\n",
    "# Record our end time\n",
    "end.record()\n",
    "# Synchronize (cudaDeviceSynchronize())\n",
    "torch.cuda.synchronize()\n",
    "\n",
    "# Calculate the time spent when launching kernels in a stream\n",
    "# Results are in ms\n",
    "stream_time = start.elapsed_time(end)\n",
    "\n",
    "# Warm up the GPU again with one graph replay\n",
    "g.replay()\n",
    "# Record our start time\n",
    "start.record()\n",
    "# Run our graph\n",
    "g.replay()\n",
    "# Record our end time\n",
    "end.record()\n",
    "# Synchronize (cudaDeviceSynchronize())\n",
    "torch.cuda.synchronize()\n",
    "\n",
    "# Calculate the time spent when launching kernels in a graph\n",
    "# Results are in ms\n",
    "graph_time = start.elapsed_time(end)"
   ]
  },
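  {
   "cell_type": "markdown",
   "id": "cuda-graphs-speedup-formula",
   "metadata": {},
   "source": [
    "The speedup reported below is computed as `(stream_time - graph_time) / graph_time`, which equals `stream_time / graph_time - 1`. It measures how much faster the graph is than direct stream launches. This is not the same as the fraction of time saved, `(stream_time - graph_time) / stream_time`. A small sketch using illustrative timings (10 ms and 8 ms, not measured values) shows the difference: a 25% speedup corresponds to 20% less elapsed time.\n",
    "\n",
    "```python\n",
    "def percent_speedup(stream_ms, graph_ms):\n",
    "    # Same formula as the next cell: how much faster the graph is, in percent\n",
    "    return (stream_ms - graph_ms) / graph_ms * 100.0\n",
    "\n",
    "\n",
    "def percent_time_saved(stream_ms, graph_ms):\n",
    "    # Fraction of the stream-launch time removed by the graph, in percent\n",
    "    return (stream_ms - graph_ms) / stream_ms * 100.0\n",
    "\n",
    "\n",
    "print(percent_speedup(10.0, 8.0))  # 25.0\n",
    "print(percent_time_saved(10.0, 8.0))  # 20.0\n",
    "```"
   ]
  },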
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "12b8151a-46b3-4c99-9945-301f6b628131",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8.94% speedup when using CUDA graphs for this kernel!\n"
     ]
    }
   ],
   "source": [
    "# Speedup of the graph over stream launches (stream_time / graph_time - 1), printed as a percentage\n",
    "percent_speedup = (stream_time - graph_time) / graph_time\n",
    "print(f\"{percent_speedup * 100.0:.2f}% speedup when using CUDA graphs for this kernel!\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
